Deep Learning: Mini Challenge
Yvo Keller, BSc Data Science, HS23
My goal with this notebook is to show the direct relation between the paper (Show, Attend and Tell) and its implementation within this repository. I will try to explain the code as much as possible, but nonetheless recommend you read the paper first.
Important: The code in this notebook (with the exception of plotting captions and EDA) is for explanation purposes and NOT fit to run in the notebook. For runnable code, please refer to the repo's README and directly to the referenced scripts.
Resources:
In the implementation of the "Show, Attend and Tell" paper, a crucial step is preparing the dataset for training and evaluation.
Below, I'll discuss the key components of the generate_json_data script and their significance in the context of the project.
The script uses Karpathy's data splits, which are a standard way of dividing the dataset in image captioning tasks. This approach divides the dataset into training, validation, and test sets, allowing for a fair comparison with the results reported in the original "Show, Attend and Tell" paper.
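To make the split format concrete, here is a minimal sketch (with a made-up three-image `dataset` dict, not the real Karpathy file) of how each image entry carries its own `split` field, which is all the script needs to route images into the three sets:

```python
from collections import Counter

# Hypothetical minimal dataset.json in Karpathy-split format: each image
# carries a 'split' field and a list of sentences with pre-tokenized captions.
dataset = {
    "images": [
        {"filename": "a.jpg", "split": "train",
         "sentences": [{"tokens": ["a", "dog", "runs"]}]},
        {"filename": "b.jpg", "split": "val",
         "sentences": [{"tokens": ["a", "cat", "sits"]}]},
        {"filename": "c.jpg", "split": "test",
         "sentences": [{"tokens": ["two", "dogs", "play"]}]},
    ]
}

# Count images per split, mirroring how generate_json_data routes each
# image into the train/val/test lists based on img['split'].
split_counts = Counter(img["split"] for img in dataset["images"])
print(split_counts)  # Counter({'train': 1, 'val': 1, 'test': 1})
```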
Tokenization is the process of splitting text into tokens (words or sub-words), which are then mapped to numerical indices. This script uses pre-tokenized captions (as indicated by sentence['tokens']). Special tokens such as <start>, <eos>, <unk>, and <pad> are added to the vocabulary. These tokens serve specific purposes:
- <start>: Marks the beginning of a caption.
- <eos>: Signifies the end of a caption.
- <unk>: Represents unknown words not frequent enough in the dataset.
- <pad>: Used for padding shorter captions to a uniform length.

The max_caption_length parameter, set to 25 by default, defines the maximum length of captions. This length is chosen based on the observation that most captions in the dataset are shorter than 25 tokens. Setting a maximum length is also beneficial for computational efficiency, as it ensures a uniform tensor size for batch processing during training. In case a caption exceeds this length, it is truncated. As the <start> and <eos> tokens are added afterwards, this results in a final sequence length of 27 for all captions.
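The encoding scheme above can be sketched on a toy vocabulary (the `word_dict` and `encode` helper below are illustrative, not the repo's code; only the special-token indices and the padding arithmetic follow the script):

```python
# Toy vocabulary with the four special tokens at their fixed indices,
# mirroring the word_dict built by generate_json_data.
word_dict = {'<start>': 0, '<eos>': 1, '<unk>': 2, '<pad>': 3,
             'a': 4, 'dog': 5, 'runs': 6}
MAX_CAPTION_LENGTH = 25

def encode(tokens, word_dict, max_length=MAX_CAPTION_LENGTH):
    # Truncate, map out-of-vocabulary words to <unk>, then wrap the caption
    # in <start>/<eos> and pad to a uniform length of max_length + 2.
    tokens = tokens[:max_length]
    idxs = [word_dict.get(t, word_dict['<unk>']) for t in tokens]
    return ([word_dict['<start>']] + idxs + [word_dict['<eos>']]
            + [word_dict['<pad>']] * (max_length - len(tokens)))

caption = encode(['a', 'dog', 'runs', 'quickly'], word_dict)
print(len(caption))   # 27: 25 tokens/padding + <start> + <eos>
print(caption[:6])    # [0, 4, 5, 6, 2, 1] -- 'quickly' maps to <unk>
```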
The script processes and saves the data in JSON format, including image paths and corresponding tokenized captions. This preprocessing step speeds up the training process, as the data is already tokenized and split according to the required format. By loading preprocessed data, the model avoids the overhead of performing these operations during training iterations, leading to faster and more efficient training.
The script uses argparse to allow easy customization of input parameters like data paths and thresholds for word frequency. Words occurring fewer times than min_word_count (default 5) are excluded from the vocabulary; wherever they appear, the <unk> token is assigned. This script is located at generate_json_data.py. I wrote a second script, which performs the same preparation steps, just using the BERT tokenizer. That can be found at generate_json_data_bert.py.
import argparse
import json
from collections import Counter


def generate_json_data(split_path, data_path, max_captions_per_image, min_word_count, max_caption_length):
    split = json.load(open(split_path, 'r'))

    word_count = Counter()
    train_img_paths = []
    train_caption_tokens = []
    validation_img_paths = []
    validation_caption_tokens = []
    test_img_paths = []
    test_caption_tokens = []
    max_length = 0

    for img in split['images']:
        caption_count = 0
        for sentence in img['sentences']:
            if caption_count < max_captions_per_image:
                caption_count += 1
            else:
                break

            # support flickr8k datasets.json that doesn't have subfolders
            filepath_defined = 'filepath' in img
            img_path = f"{data_path}/imgs{'/' + img['filepath'] if filepath_defined else ''}/{img['filename']}"

            if img['split'] == 'train':
                train_img_paths.append(img_path)
                train_caption_tokens.append(sentence['tokens'])
            elif img['split'] == 'val':
                validation_img_paths.append(img_path)
                validation_caption_tokens.append(sentence['tokens'])
            elif img['split'] == 'test':
                test_img_paths.append(img_path)
                test_caption_tokens.append(sentence['tokens'])

            max_length = max(max_length, len(sentence['tokens']))
            word_count.update(sentence['tokens'])

    words = [word for word in word_count.keys() if word_count[word] >= min_word_count]
    word_dict = {word: idx + 4 for idx, word in enumerate(words)}
    word_dict['<start>'] = 0
    word_dict['<eos>'] = 1
    word_dict['<unk>'] = 2
    word_dict['<pad>'] = 3

    with open(data_path + '/word_dict.json', 'w') as f:
        json.dump(word_dict, f)

    max_length = min(max_length, max_caption_length)
    train_captions = process_caption_tokens(train_caption_tokens, word_dict, max_length)
    validation_captions = process_caption_tokens(validation_caption_tokens, word_dict, max_length)
    test_captions = process_caption_tokens(test_caption_tokens, word_dict, max_length)

    with open(data_path + '/train_img_paths.json', 'w') as f:
        json.dump(train_img_paths, f)
    with open(data_path + '/val_img_paths.json', 'w') as f:
        json.dump(validation_img_paths, f)
    with open(data_path + '/train_captions.json', 'w') as f:
        json.dump(train_captions, f)
    with open(data_path + '/val_captions.json', 'w') as f:
        json.dump(validation_captions, f)
    with open(data_path + '/test_img_paths.json', 'w') as f:
        json.dump(test_img_paths, f)
    with open(data_path + '/test_captions.json', 'w') as f:
        json.dump(test_captions, f)


def process_caption_tokens(caption_tokens, word_dict, max_length):
    captions = []
    for tokens in caption_tokens:
        tokens = tokens[:max_length]
        token_idxs = [word_dict[token] if token in word_dict else word_dict['<unk>'] for token in tokens]
        # <start> + tokens + <eos>, padded to a uniform length of max_length + 2
        captions.append([word_dict['<start>']] + token_idxs + [word_dict['<eos>']]
                        + [word_dict['<pad>']] * (max_length - len(tokens)))
    return captions


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate json files')
    parser.add_argument('--split-path', type=str, default='data/coco/dataset.json')
    parser.add_argument('--data-path', type=str, default='data/coco')
    parser.add_argument('--max-captions', type=int, default=5,
                        help='maximum number of captions per image')
    parser.add_argument('--min-word-count', type=int, default=5,
                        help='minimum number of occurrences of a word to be included in word dictionary')
    parser.add_argument('--max-caption-length', type=int, default=25,
                        help='maximum number of tokens in a caption')
    args = parser.parse_args()

    generate_json_data(args.split_path, args.data_path, args.max_captions, args.min_word_count, args.max_caption_length)
The next important step is providing the prepared data as a DataLoader. I have focused on optimizing the performance of the data pipeline, which was crucial for efficient training.
In my DataLoader configuration, I've set pin_memory=True. This is a performance optimization in PyTorch that is particularly beneficial when using CUDA-enabled GPUs, or MPS (on M2) in my case. By enabling pinned memory, the DataLoader automatically places the fetched data Tensors in pinned memory, facilitating faster data transfer to the GPU.
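A minimal sketch of that DataLoader configuration, on dummy tensors rather than the real dataset (the shapes and batch size are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-ins for preprocessed images and 27-token captions.
images = torch.randn(8, 3, 224, 224)
captions = torch.randint(0, 100, (8, 27))

# pin_memory=True places fetched batches in page-locked host memory, so the
# subsequent host-to-device copy (CUDA or MPS) can be faster.
loader = DataLoader(TensorDataset(images, captions),
                    batch_size=4, shuffle=True, pin_memory=True)

img_batch, cap_batch = next(iter(loader))
print(img_batch.shape, cap_batch.shape)
```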
In the ImageCaptionDataset class, I handle all preprocessing of images and captions before training starts. By performing image loading, transformation, and caption preprocessing just once and storing them in memory, I avoid redundant processing for each epoch or batch during training. This strategy is particularly advantageous for large datasets, as it significantly reduces I/O operations and processing time during the training loops.
Image Loading and Transformation: I load the image paths and captions from JSON files. The pil_loader function is used to load images, which are then transformed and converted to tensors.
BERT Embeddings Compatibility: The class can handle both standard and BERT embeddings, in line with the unique aspect of my model, where I experiment with both newly trained and pretrained BERT embeddings.
Fractional Dataset Utilization: I've added the capability to use only a fraction of the dataset, controlled by the fraction argument. This feature proved really useful for quick iterations and debugging.
Efficient Data Storage: I store preprocessed image and caption tensors in a list (self.data), which enables fast data retrieval during training.
Handling Multiple Captions: The class is designed to efficiently provide all 5 captions that exist for an image, in addition to the one caption currently being trained on. This is relevant for calculating the BLEU score in the validation and testing phases.
This script is located at dataset.py.
import json
from collections import defaultdict

import torch
from torch.utils.data import Dataset
from PIL import Image


def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class ImageCaptionDataset(Dataset):
    def __init__(self, transform, data_path, split_type='train', fraction=1.0, bert=False):
        super(ImageCaptionDataset, self).__init__()
        self.transform = transform

        # Load image paths and captions
        img_paths = json.load(open(data_path + f'/{split_type}_img_paths.json', 'r'))
        if bert:
            captions = json.load(open(data_path + f'/{split_type}_captions_bert.json', 'r'))
        else:
            captions = json.load(open(data_path + f'/{split_type}_captions.json', 'r'))

        # Reduce dataset size if fraction is not 1.0
        if fraction != 1.0:
            img_paths = img_paths[:int(len(img_paths) * fraction)]
            captions = captions[:int(len(captions) * fraction)]

        # Preprocess and store data
        self.data = []
        all_captions = defaultdict(list)  # Store all captions for each image path
        for img_path, caption in zip(img_paths, captions):
            img = pil_loader(img_path)
            if self.transform is not None:
                img = self.transform(img)
            self.data.append((torch.FloatTensor(img), torch.tensor(caption)))
            all_captions[img_path].append(caption)

        # Convert all_captions dictionary to a list matching the order of images
        self.all_captions = [all_captions[path] for path in img_paths]

    def __getitem__(self, index):
        img_tensor, caption_tensor = self.data[index]
        all_captions_tensor = torch.tensor(self.all_captions[index])
        return img_tensor, caption_tensor, all_captions_tensor

    def __len__(self):
        return len(self.data)
In this section, I perform some explorative data analysis on the dataset that is later used for training. It thus already includes all changes (truncation to the max caption length, min word count of 5, etc.) defined in the prior data preparation section.
import json

import spacy
import nltk
from nltk.corpus import stopwords
import pandas as pd
from wordcloud import WordCloud
import matplotlib.pyplot as plt

nltk.download('stopwords')
!python -m spacy download en_core_web_sm
en_core_web_sm = spacy.load("en_core_web_sm")

DATA_PATH = 'data/flickr8k/'
SPLIT_TYPES = ['train', 'val', 'test']

captions = []
for split_type in SPLIT_TYPES:
    captions.extend(json.load(open(DATA_PATH + f'/{split_type}_captions.json', 'r')))
print(f'Loaded {len(captions)} captions')
word_dict = json.load(open(DATA_PATH + '/word_dict.json', 'r'))
vocabulary_size = len(word_dict)
print(f'Vocabulary size: {vocabulary_size}')

# decode captions back to words via an inverted index-to-word lookup
idx_to_word = {idx: word for word, idx in word_dict.items()}
decoded_captions = []
for caption in captions:
    decoded_caption = []
    for idx in caption:
        if idx == word_dict['<start>']:
            continue
        elif idx == word_dict['<eos>']:
            break
        else:
            decoded_caption.append(idx_to_word[idx])
    decoded_captions.append(decoded_caption)
Loaded 40000 captions
Vocabulary size: 2945
def generate_count_wordcloud(df: pd.DataFrame, top_n_words=30):
    en_stopwords = set(stopwords.words('english'))

    # create custom tokenizer that removes stopwords
    def spacy_tokenizer(text):
        tokens = en_core_web_sm(text)
        return [token for token in tokens if token.text not in en_stopwords]

    tokens = df.text.apply(spacy_tokenizer)
    lowercase_tokens = [token.lower_ for doc in tokens for token in doc]

    # create wordcloud
    wordcloud = WordCloud(
        width=800, height=400, background_color="white", max_words=top_n_words
    ).generate(" ".join(lowercase_tokens))

    # show wordcloud
    plt.figure(figsize=(12, 10))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title('Word Cloud Flickr8k Image Captions')
    plt.show()


# create dataframe with captions
decoded_captions_joined = [' '.join(caption) for caption in decoded_captions]
df = pd.DataFrame({'text': decoded_captions_joined})
print(f'Sample decoded caption: {decoded_captions_joined[0]}')

# generate wordcloud
generate_count_wordcloud(df)
Sample decoded caption: a black dog is running after a white dog in the snow
# Distribution of caption lengths
caption_lengths = [len(caption) for caption in decoded_captions]
df_cl = pd.DataFrame({'caption_length': caption_lengths})

plt.figure()
plt.bar(df_cl.caption_length.value_counts().index, df_cl.caption_length.value_counts())
plt.title('Caption Length Distribution')
plt.xlabel('Caption Length')
plt.ylabel('Count')
plt.show()
As we can see, the average caption length is around 12 words. The word cloud shows that topics like dog, person, woman, boy, and little girl are very common in the Flickr8k dataset, so the model should be able to learn these topics very well.
Flickr8k contains a total of 8000 images with 5 captions each. The images are of different sizes, but are all scaled to 224x224 pixels before training.
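The resizing step can be sketched with plain PIL (in the actual pipeline this is done by the torchvision transform passed into ImageCaptionDataset; the 500x375 source size below is just a stand-in for a Flickr8k photo):

```python
from PIL import Image

# Every image, whatever its original size, is scaled to 224x224 pixels
# before being fed to the CNN encoder.
img = Image.new('RGB', (500, 375))   # stand-in for a Flickr8k photo
resized = img.resize((224, 224))
print(resized.size)  # (224, 224)
```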
Section 3.1.1 in the "Show, Attend and Tell" paper describes the encoder. The proposed model uses a CNN to extract a set of feature vectors, referred to as annotation vectors. This code implements that concept by using pre-trained models (VGG19, ResNet152, or DenseNet161), modifying them to exclude the final classification layers, and reshaping the output to form a set of feature vectors (our annotation vectors). Each vector represents a different spatial part of the image, which is key for the attention mechanism in the next stages of the model.
Specifically for VGG19, which I will be using in combination with the Flickr8k dataset, this means keeping the CNN's layers up to, but excluding, the last pooling layer. This is done by using the features part of the VGG19 model, which is a Sequential object, and removing its last layer, the final pooling layer.
VGG(
  (features): Sequential(
    ...
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True) <<<<< this is the last layer we use
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    ...
  )
)
Here is the full Encoder class, which is located at encoder.py.
import torch.nn as nn
from torchvision.models import densenet161, resnet152, vgg19
from torchvision.models import VGG19_Weights


class Encoder(nn.Module):
    """
    Encoder network for image feature extraction, follows section 3.1.1 of the paper
    """
    def __init__(self, network='vgg19'):
        super(Encoder, self).__init__()
        self.network = network

        # Selection of pre-trained CNNs for feature extraction
        if network == 'resnet152':
            self.net = resnet152(pretrained=True)
            # Removing the final fully connected layers of ResNet152
            self.net = nn.Sequential(*list(self.net.children())[:-2])
            self.dim = 2048  # Dimension of feature vectors for ResNet152
        elif network == 'densenet161':
            self.net = densenet161(pretrained=True)
            # Removing the final layers of DenseNet161
            self.net = nn.Sequential(*list(list(self.net.children())[0])[:-1])
            self.dim = 1920  # Dimension of feature vectors for DenseNet161
        else:
            self.net = vgg19(weights=VGG19_Weights.DEFAULT)
            # Using features from VGG19, excluding the last pooling layer
            self.net = nn.Sequential(*list(self.net.features.children())[:-1])
            self.dim = 512  # Dimension of feature vectors for VGG19

        # Freezing the weights of the pre-trained CNN
        for params in self.net.parameters():
            params.requires_grad = False

    def forward(self, x):
        x = self.net(x)
        # These steps correspond to the extraction of annotation vectors
        # (a = {a1, ..., aL}) as described in Section 3.1.1 of the paper.
        # 1. Change the order from (BS, C, H, W) to (BS, H, W, C) in prep for reshaping
        x = x.permute(0, 2, 3, 1)
        # 2. Reshape to [BS, num_spatial_features, C]; the -1 flattens the
        #    height and width dimensions into a single dimension
        x = x.view(x.size(0), -1, x.size(-1))
        return x
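The reshaping in forward() can be demonstrated on a dummy feature map. For a 224x224 input, VGG19's truncated features stack yields a (BS, 512, 14, 14) tensor; permuting and flattening the spatial dimensions gives 14 * 14 = 196 annotation vectors of dimension 512 per image (the random tensor below just stands in for the CNN output):

```python
import torch

# Dummy VGG19 feature map for a batch of 2 images.
x = torch.randn(2, 512, 14, 14)

x = x.permute(0, 2, 3, 1)              # (BS, H, W, C)
x = x.view(x.size(0), -1, x.size(-1))  # (BS, 196, 512): 196 annotation vectors
print(x.shape)  # torch.Size([2, 196, 512])
```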
Let's move forward to Section 3.1.2, explaining the Decoder.
On a high level, it works as follows:
- The LSTM's initial hidden and cell states are computed from the mean of the image's annotation vectors.
- At each time step, the attention mechanism produces a context vector from the annotation vectors and the current hidden state.
- The LSTM consumes the context vector together with the embedding of the previous word: the ground-truth word when teacher forcing is enabled with tf=True, otherwise the model's own previous prediction.
- A deep output layer maps the LSTM state to a probability distribution over the vocabulary.

Provided below is the full implementation of the Decoder. I will break the code down into smaller chunks to explore it in more detail, based on the paper's description of the Decoder.
The full code is located at decoder.py.
import torch
import torch.nn as nn

from attention import Attention

# Device for all tensors created below (in the repo this is configured by the
# training script; defined here so the references in forward() resolve)
mps_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')


class Decoder(nn.Module):
    def __init__(self, vocabulary_size, encoder_dim, tf=False, ado=False, bert=False, attention=False):
        super(Decoder, self).__init__()
        self.use_tf = tf
        self.use_advanced_deep_output = ado
        self.use_bert = bert
        self.use_attention = attention

        # Initializing parameters
        self.encoder_dim = encoder_dim

        # Embeddings
        if bert:
            from transformers import BertModel, BertTokenizer
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.vocabulary_size = self.bert_model.config.vocab_size
            self.embedding_size = self.bert_model.config.hidden_size  # 768
            # Embedding layer using BERT's embeddings
            self.embedding = self.bert_model.get_input_embeddings()
            # Freeze the BERT embeddings
            for param in self.embedding.parameters():
                param.requires_grad = False
            # Delete the BERT model to save memory (and checkpoint size)
            del self.bert_model
        else:
            self.vocabulary_size = vocabulary_size
            self.embedding_size = 512
            self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_size)  # Embedding layer for input words

        # Initial LSTM cell state generators
        self.init_h = nn.Linear(encoder_dim, self.embedding_size)  # For hidden state
        self.init_c = nn.Linear(encoder_dim, self.embedding_size)  # For cell state
        self.tanh = nn.Tanh()

        # Attention mechanism related layers
        self.f_beta = nn.Linear(self.embedding_size, encoder_dim)  # Gating scalar in attention mechanism
        self.sigmoid = nn.Sigmoid()

        # Attention and LSTM components
        self.attention = Attention(encoder_dim, self.embedding_size)  # Attention network
        self.lstm = nn.LSTMCell(self.embedding_size + encoder_dim, self.embedding_size)  # LSTM cell

        # Deep output layers
        if self.use_advanced_deep_output:
            # Advanced DO: Layers for transforming LSTM state, context vector and embedding for DO-RNN
            hidden_dim, intermediate_dim = self.embedding_size, self.embedding_size
            self.f_h = nn.Linear(hidden_dim, intermediate_dim)  # Transforms LSTM hidden state
            self.f_z = nn.Linear(encoder_dim, intermediate_dim)  # Transforms context vector
            self.f_out = nn.Linear(intermediate_dim, self.vocabulary_size)  # Transforms the combined vector (sum of embedding, LSTM state, and context vector) to voc_size
            self.relu = nn.ReLU()  # Activation function
            self.dropout = nn.Dropout()

        # Simple DO: Layer for transforming LSTM state to vocabulary
        self.deep_output = nn.Linear(self.embedding_size, self.vocabulary_size)  # Maps LSTM outputs to vocabulary
        self.dropout = nn.Dropout()

    def forward(self, img_features, captions):
        # Forward pass of the decoder
        batch_size = img_features.size(0)

        # Initialize LSTM state
        h, c = self.get_init_lstm_state(img_features)

        # Teacher forcing setup
        max_timespan = max([len(caption) for caption in captions]) - 1
        if self.use_bert:
            start_token = torch.full((batch_size, 1), self.tokenizer.cls_token_id).long().to(mps_device)
        else:
            start_token = torch.zeros(batch_size, 1).long().to(mps_device)

        # Convert caption tokens to their embeddings
        if self.use_tf:
            caption_embedding = self.embedding(captions)
        else:
            previous_predicted_token_embedding = self.embedding(start_token)

        # Preparing to store predictions and attention weights
        preds = torch.zeros(batch_size, max_timespan, self.vocabulary_size).to(mps_device)  # [BATCH_SIZE, TIME_STEPS, VOC_SIZE]
        alphas = torch.zeros(batch_size, max_timespan, img_features.size(1)).to(mps_device)  # [BATCH_SIZE, TIME_STEPS, NUM_SPATIAL_FEATURES]

        # Generating captions
        for t in range(max_timespan):
            if self.use_attention:
                context, alpha = self.attention(img_features, h)  # Compute context vector via attention
                gate = self.sigmoid(self.f_beta(h))  # Gating scalar for context
                gated_context = gate * context  # Apply gate to context
            else:
                # If not using attention, treat all parts of the image equally
                alpha = torch.full((batch_size, img_features.size(1)), 1.0 / img_features.size(1), device=mps_device)  # Uniform attention
                context = img_features.mean(dim=1)  # Simply take the mean of the image features
                gated_context = context  # No gating applied

            # Prepare LSTM input
            if self.use_tf:
                lstm_input = torch.cat((caption_embedding[:, t], gated_context), dim=1)  # current embedding + context vector as input vector
            else:
                previous_predicted_token_embedding = previous_predicted_token_embedding.squeeze(1) if previous_predicted_token_embedding.dim() == 3 else previous_predicted_token_embedding
                lstm_input = torch.cat((previous_predicted_token_embedding, gated_context), dim=1)

            # LSTM forward pass
            h, c = self.lstm(lstm_input, (h, c))

            # Generate word prediction
            if self.use_advanced_deep_output:
                # NOTE: could explore alternative positions for dropout
                if self.use_tf:
                    output = self.advanced_deep_output(self.dropout(h), context, caption_embedding[:, t])
                else:
                    output = self.advanced_deep_output(self.dropout(h), context, previous_predicted_token_embedding)
            else:
                output = self.deep_output(self.dropout(h))

            preds[:, t] = output  # Store predictions
            alphas[:, t] = alpha  # Store attention weights

            # Prepare next input word
            if not self.use_tf:
                predicted_token_idxs = output.max(1)[1].reshape(batch_size, 1)  # index of the token with the highest probability
                previous_predicted_token_embedding = self.embedding(predicted_token_idxs)

        return preds, alphas

    def get_init_lstm_state(self, img_features):
        # Initializing LSTM state based on image features
        avg_features = img_features.mean(dim=1)
        c = self.init_c(avg_features)  # Cell state
        c = self.tanh(c)
        h = self.init_h(avg_features)  # Hidden state
        h = self.tanh(h)
        return h, c

    def advanced_deep_output(self, h, context, current_embedding):
        # Combine the LSTM state and context vector
        h_transformed = self.relu(self.f_h(h))
        z_transformed = self.relu(self.f_z(context))
        # Sum the transformed vectors with the embedding
        combined = h_transformed + z_transformed + current_embedding
        # Transform the combined vector & compute the output word probability
        return self.relu(self.f_out(combined))
At the heart of the Show, Attend and Tell model is the attention mechanism, which lets the model focus on different parts of the image while generating the caption. It is implemented as a separate module that the Decoder calls at each time step.
There are two main types of attention mechanisms: soft attention and hard attention. Soft attention is differentiable and allows for end-to-end training, while hard attention is non-differentiable and requires reinforcement learning to train. I want to break down both types theoretically and then show how they are implemented in the code.
Stochastic "Hard" Attention is an approach where the model discretely chooses specific regions (or locations, i.e. one annotation vector) in an image to focus on at each step of generating a caption. This contrasts with "Soft" Attention, where the model considers all regions but with varying degrees of focus.
$ s_{t, i} = 1 $ if the $ i $-th location is chosen at time $ t $, out of $ L $ total locations.
Significance: Represents the model's decision on where to focus in the image when generating the $ t $-th word in the caption as a one-hot vector.
$$ \hat{\mathbf{z}}_t = \sum_{i=1}^{L} s_{t, i} \mathbf{a}_i $$
Significance: Computes the context vector as the feature vector of the selected image region. Only the chosen region contributes to the context at each step.
$$ p(s_{t, i} = 1 \mid s_{j<t}, \mathbf{a}) = \alpha_{t, i} $$
Significance: Models the attention decision as a random variable, following a Multinoulli distribution. The attention weights $ \{\alpha_i\} $ determine the probability of focusing on each region.
Variational Lower Bound $$ L_s = \sum_s p(s \mid \mathbf{a}) \log p(\mathbf{y} \mid s, \mathbf{a}) \leq \log p(\mathbf{y} \mid \mathbf{a}) $$
where $ s $ is a sequence of attention location decisions and $ \mathbf{y} $ is the generated caption.
Significance: $ L_s $ serves as a computationally feasible approximation to the true log-likelihood of generating the correct caption. It is about finding the best possible set of attention decisions (where to focus in the image at each step) to maximize the probability of correctly generating the caption sequence. It's optimized during training to improve captioning accuracy.
Monte Carlo Gradient Approximation $$ \frac{\partial L_s}{\partial W} \approx \frac{1}{N} \sum_{n=1}^{N}\left[\frac{\partial \log p(\mathbf{y} \mid \tilde{s}^n, \mathbf{a})}{\partial W} + \log p(\mathbf{y} \mid \tilde{s}^n, \mathbf{a}) \frac{\partial \log p(\tilde{s}^n \mid \mathbf{a})}{\partial W}\right] $$
where $ \tilde{s}^n \sim \text{Multinoulli}_L(\{\alpha_i\}) $ are attention sequences sampled from the attention distribution.
Significance: Provides a practical method to approximate the gradient of $ L_s $ for model optimization, as direct computation is infeasible due to the stochastic nature of hard attention.
Moving Average Baseline $$ b_k = 0.9 \times b_{k-1} + 0.1 \times \log p(\mathbf{y} | \tilde{s}_k, \mathbf{a}) $$
where...
$ b_k $ represents the moving average baseline at the $ k $-th mini-batch during training.
The formula for $ b_k $ involves an exponential decay component, which is a method commonly used to calculate a moving average that gives more weight to recent observations. In this case, the decay is controlled by the coefficient $ 0.9 $. This coefficient multiplies the previous baseline $ b_{k-1} $, effectively reducing its influence over time.
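The update is easy to trace by hand; below, a toy sequence of per-minibatch log-likelihoods (the values are made up for illustration) is folded into the baseline:

```python
# b_k = 0.9 * b_{k-1} + 0.1 * log p(y | s~_k, a), starting from b_0 = 0.
log_likelihoods = [-5.0, -4.0, -3.5, -3.0]  # made-up minibatch log-likelihoods

b = 0.0
for log_p in log_likelihoods:
    b = 0.9 * b + 0.1 * log_p
print(round(b, 4))  # -1.3035
```

Each step shrinks the old baseline by a factor of 0.9 and mixes in 10% of the newest observation, so recent minibatches dominate the average.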
Significance: Reduces the variance in the Monte Carlo estimator of the gradient, stabilizing training.
Entropy Regularization $$ \lambda_e \frac{\partial H[\tilde{s}^n]}{\partial W} $$
where...
$ H[\tilde{s}^n] $ is the entropy of the sampled attention sequence $ \tilde{s}^n $. By adding the entropy of the attention distribution to the objective function, the model is encouraged to maintain a degree of uncertainty in its attention decisions. This encouragement for higher entropy effectively promotes exploration in the model's attention mechanism. Instead of always focusing on the same regions for similar images or features, the model is nudged to explore other potentially informative regions as well.
$ \lambda_e $ is a hyperparameter controlling the strength of the entropy regularization.
Significance: Encourages exploration in attention decisions, further reducing variance and improving model robustness. A model that explores more diverse attention strategies is less likely to get stuck in local optima and can generalize better.
Final Learning Rule $$ \frac{\partial L_s}{\partial W} \approx \frac{1}{N} \sum_{n=1}^{N}\left[\frac{\partial \log p(\mathbf{y} \mid \tilde{s}^n, \mathbf{a})}{\partial W} + \lambda_r\left(\log p(\mathbf{y} \mid \tilde{s}^n, \mathbf{a}) - b\right) \frac{\partial \log p(\tilde{s}^n \mid \mathbf{a})}{\partial W} + \lambda_e \frac{\partial H[\tilde{s}^n]}{\partial W}\right] $$
where $ \lambda_r $ and $ \lambda_e $ are hyperparameters weighting the baseline-corrected reward term and the entropy term.
Significance: Combines all elements (gradient approximation, baseline, and entropy regularization) into a single learning rule for training the model with hard attention.
This approach aligns with the REINFORCE rule from reinforcement learning, treating the sequence of attention decisions as actions with associated rewards based on the log likelihood of the generated caption.
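To make the REINFORCE-style estimate tangible, here is a toy score-function estimator for a single attention step over L = 3 regions, with a made-up reward standing in for the caption log-likelihood (everything here is illustrative, not the paper's training code). For a Multinoulli with probabilities from a softmax, the score is $ \mathbb{1}[s = j] - \alpha_j $:

```python
import random

random.seed(0)
alpha = [0.6, 0.3, 0.1]    # attention probabilities over 3 regions
reward = [1.0, 0.5, 0.0]   # made-up reward r(s) per chosen region

# Monte Carlo estimate of d E[r(s)] / d logits via the score function:
# mean over samples of r(s) * (1[s == j] - alpha_j).
N = 200_000
grad_est = [0.0, 0.0, 0.0]
for _ in range(N):
    s = random.choices(range(3), weights=alpha)[0]  # sample ~ Multinoulli(alpha)
    for j in range(3):
        grad_est[j] += reward[s] * ((1.0 if s == j else 0.0) - alpha[j]) / N

# Exact gradient for comparison: sum_s p(s) r(s) (1[s == j] - alpha_j)
grad_exact = [sum(alpha[s] * reward[s] * ((1.0 if s == j else 0.0) - alpha[j])
                  for s in range(3)) for j in range(3)]
print([round(g, 3) for g in grad_exact])  # [0.15, -0.075, -0.075]
```

With enough samples the stochastic estimate converges to the exact gradient; the baseline and entropy terms in the paper exist precisely to tame the variance of this estimator at practical sample sizes.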
Unlike stochastic hard attention, which involves random sampling (where the model discretely chooses specific regions to focus on), soft attention deterministically calculates a weighted sum of all parts of the input, allowing for straightforward optimization and learning. Thus, in soft attention, the model considers all regions (or locations, i.e. all annotation vectors) in an image at each step of generating a caption, but with varying degrees of focus.
Expected Context Vector $$ \mathbb{E}_{p(s_t \mid \mathbf{a})}[\hat{\mathbf{z}}_t] = \sum_{i=1}^{L} \alpha_{t, i} \mathbf{a}_i $$
where...
The weights $ \alpha_{t, i} $ are the attention probabilities for each region at time step $ t $
$ L $ is the total number of regions.
Explanation: This formula represents the expected context vector as a weighted sum of all annotation vectors $ \mathbf{a}_i $ from the image.
Significance: It provides a 'soft' focus by blending information from all parts of the image, with more emphasis on the areas deemed most relevant by the model.
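Numerically, the context vector is just a convex combination of the annotation vectors (toy 2-dimensional vectors below, chosen for readability):

```python
# Soft attention on 3 toy annotation vectors a_1..a_3 of dimension 2.
alpha = [0.7, 0.2, 0.1]          # attention probabilities, sum to 1
annotations = [[1.0, 0.0],       # a_1
               [0.0, 1.0],       # a_2
               [1.0, 1.0]]       # a_3

# context = sum_i alpha_i * a_i, computed per dimension
context = [sum(alpha[i] * annotations[i][d] for i in range(3))
           for d in range(2)]
print([round(c, 3) for c in context])  # [0.8, 0.3]
```

Because every region contributes (weighted by its alpha), the whole operation is differentiable, which is exactly what makes end-to-end training possible.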
Deterministic Soft Attention offers a practical and efficient method for implementing attention mechanisms in neural networks, especially for tasks like image captioning. By calculating a weighted sum of input features and avoiding the complexity of stochastic sampling, it facilitates smooth and differentiable models that are amenable to standard training techniques. This approach enables the model to effectively focus on relevant parts of the input while maintaining computational tractability and ease of training.
Enough theory, let's apply this knowledge to the code. The attention mechanism is implemented as a separate module used by the decoder at each time step. I use the deterministic soft attention mechanism described above. The reason for this is that implementing stochastic hard attention is dramatically more complex and would require reinforcement learning to train the model, which is beyond the scope of the deep learning course for which I am implementing this project. It is worth noting, however, that Stochastic Hard Attention performed slightly better than Deterministic Soft Attention in the original paper, as measured by the BLEU score.
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, encoder_dim, hidden_dim=512):
        super(Attention, self).__init__()
        self.U = nn.Linear(hidden_dim, 512)   # transforms the decoder hidden state
        self.W = nn.Linear(encoder_dim, 512)  # transforms the annotation vectors
        self.v = nn.Linear(512, 1)            # scores each spatial location
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(1)

    def forward(self, img_features, hidden_state):
        U_h = self.U(hidden_state).unsqueeze(1)
        W_s = self.W(img_features)
        att = self.tanh(W_s + U_h)
        e = self.v(att).squeeze(2)
        alpha = self.softmax(e)
        context = (img_features * alpha.unsqueeze(2)).sum(1)
        return context, alpha
Formula: $ \mathbb{E}_{p(s_t \mid \mathbf{a})}[\hat{\mathbf{z}}_t] = \sum_{i=1}^L \alpha_{t, i} \mathbf{a}_i $
Code Implementation (in Attention class):
Attention Weights Calculation:
U_h = self.U(hidden_state).unsqueeze(1)
W_s = self.W(img_features)
att = self.tanh(W_s + U_h)
e = self.v(att).squeeze(2)
alpha = self.softmax(e)
Here, U_h and W_s are the transformed hidden state and image features, respectively. alpha is the attention probability for each region in the image.
Context Vector Calculation:
context = (img_features * alpha.unsqueeze(2)).sum(1)
This line then computes the weighted sum of the image features based on the attention weights, resulting in the context vector.
In an effort to make the implementation more closely align with the paper and add new features, I made many changes to the base implementation.
I will highlight the most important ones here, before going into more detail below:
A non-comprehensive list of other changes I made to the implementation:
- Fixed the accuracy calculation in utils.py to not be artificially inflated by padding tokens being counted as correct.
- Cut off generated captions at the first <eos> token occurrence, instead of letting the model predict to the max sequence length and basing the BLEU score on that.
- Added the fraction argument in the ImageCaptionDataset class. This is useful for quick iterations or debugging.
- Enabled pin_memory in the DataLoader to speed up data transfer to the GPU.
- Optimized the calculate_caption_lengths function in utils.py, as it was a significant bottleneck which the PyTorch bottleneck profiler highlighted.

As described towards the end of section 3.1.2, Show, Attend and Tell utilizes a deep output layer (Pascanu et al., 2014) to compute the output word probability given the current state of the LSTM, the context vector from the attention mechanism, and the previously generated word.
Let's break down this formula and map its components to the code:
$$ p\left(\mathbf{y}_t \mid \mathbf{a}, \mathbf{y}_1^{t-1}\right) \propto \exp \left(\mathbf{L}_o\left(\mathbf{E} \mathbf{y}_{t-1}+\mathbf{L}_h \mathbf{h}_t+\mathbf{L}_z \hat{\mathbf{z}}_t\right)\right) $$
Here, $ p\left(\mathbf{y}_t \mid \mathbf{a}, \mathbf{y}_1^{t-1}\right) $ is the probability of the output word $ \mathbf{y}_t $ at time $ t $, given the image features $ \mathbf{a} $ and the previously generated words $ \mathbf{y}_1^{t-1} $.
In this formula:
- $ \mathbf{E} \mathbf{y}_{t-1} $ is the embedding of the previously generated word,
- $ \mathbf{h}_t $ is the hidden state of the LSTM at time $ t $,
- $ \hat{\mathbf{z}}_t $ is the context vector from the attention mechanism,
- $ \mathbf{L}_o $, $ \mathbf{L}_h $, and $ \mathbf{L}_z $ are learned weight matrices.
Now let's map this to the Decoder's code in forward():
Embedding of the Previous Word ($ \mathbf{E} \mathbf{y}_{t-1} $): This is done using the self.embedding layer in the code.
embedding = self.embedding(prev_words)
Hidden State of the LSTM ($ \mathbf{h}_t $): The h variable in the code represents the hidden state of the LSTM at each time step.
h, c = self.lstm(lstm_input, (h, c))
Context Vector ($ \hat{\mathbf{z}}_t $): The context vector is computed by the attention mechanism in the self.attention layer.
context, alpha = self.attention(img_features, h)
Combining and Transforming for Output Prediction: The output word probability is computed by combining these elements and applying the learned weight matrices. Here, this operation is currently condensed into one self.deep_output layer transforming just the hidden state $ \mathbf{h}_t $. In a more complex or literal implementation of the Deep Output Layer as laid out in Show, Attend and Tell, you would expect to see multiple such layers, each followed by a non-linear activation function.
output = self.deep_output(self.dropout(h))
As we saw above, the paper describes the deep-output RNN as having multiple layers, each followed by a non-linear activation function. The implementation by AaronCCWong only had one layer transforming the hidden state of the LSTM. Therefore, I implemented the deep output as described in the paper, with multiple layers and non-linear activations. This can be enabled by setting the use_advanced_deep_output=True flag when training the model.
The relevant parts of the implementation:
class Decoder(nn.Module):
    def __init__(self, vocabulary_size, encoder_dim, tf=False, ado=False):
        # ...
        # Deep output layers
        if self.use_advanced_deep_output:
            # Advanced DO: layers for transforming LSTM state, context vector and embedding for the DO-RNN
            hidden_dim, intermediate_dim = self.embedding_size, self.embedding_size
            self.f_h = nn.Linear(hidden_dim, intermediate_dim)    # Transforms LSTM hidden state
            self.f_z = nn.Linear(encoder_dim, intermediate_dim)   # Transforms context vector
            self.f_out = nn.Linear(intermediate_dim, self.vocabulary_size)  # Transforms combined vector (sum of embedding, LSTM state, and context vector) to voc_size
            self.relu = nn.ReLU()  # Activation function
            self.dropout = nn.Dropout()
        else:
            # Simple DO: layer for transforming LSTM state to vocabulary
            self.deep_output = nn.Linear(self.embedding_size, self.vocabulary_size)  # Maps LSTM outputs to vocabulary
            self.dropout = nn.Dropout()
        # ...

    def forward(self, img_features, captions):
        # ...
        for t in range(max_timespan):
            # ...
            # Generate word prediction
            if self.use_advanced_deep_output:
                if self.use_tf:
                    output = self.advanced_deep_output(self.dropout(h), context, caption_embedding[:, t])
                else:
                    output = self.advanced_deep_output(self.dropout(h), context, previous_predicted_token_embedding)
            else:
                output = self.deep_output(self.dropout(h))
            # ...

    def advanced_deep_output(self, h, context, current_embedding):
        # Transform the LSTM state and context vector
        h_transformed = self.relu(self.f_h(h))
        z_transformed = self.relu(self.f_z(context))
        # Sum the transformed vectors with the embedding
        combined = h_transformed + z_transformed + current_embedding
        # Transform the combined vector & compute the output word probability
        return self.relu(self.f_out(combined))
In my implementation of the "Show, Attend and Tell" model, I explored the use of both standard embeddings and BERT embeddings in the decoder module. BERT, being a transformer-based model, provides contextually rich embeddings compared to standard word embeddings. Integrating BERT embeddings required modifications to the decoder architecture and data processing pipeline (see generate_json_data_bert.py and decoder.py for details).
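To illustrate why BERT's subword tokenization inflates token counts compared to whole-word vocabularies, here is a toy sketch of greedy longest-match-first WordPiece splitting, the scheme BERT's tokenizer uses. The vocabulary below is hypothetical; the real tokenizer operates over bert-base-uncased's full WordPiece vocabulary.

```python
def wordpiece(word, vocab):
    """Greedy longest-match-first subword split, WordPiece-style."""
    tokens, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = '##' + candidate  # continuation pieces are marked with '##'
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return ['[UNK]']  # no matching piece: the whole word becomes unknown
        tokens.append(piece)
        start = end
    return tokens

# Toy vocabulary: one word becomes two tokens, so sequences get longer
vocab = {'play', '##ing', 'dog', 'a'}
print(wordpiece('playing', vocab))  # ['play', '##ing']
print(wordpiece('dog', vocab))      # ['dog']
```

One word can thus expand into several tokens, which is exactly why the maximum caption length needs headroom when switching to BERT.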
- I increased max_caption_length to 30. This adjustment ensures that the increased token count, a consequence of BERT's finer granularity in tokenization, doesn't lead to excessive truncation of the captions.
- BERT uses the special tokens [CLS] (used at the beginning of a text to signify classification tasks) and [SEP] (used as a separator, e.g., between sentences). In my preprocessing routine, I aligned the tokenizer's bos_token (beginning of sequence) and eos_token (end of sequence) with BERT's [CLS] and [SEP] tokens respectively.
- When training with bert=True, I utilized the BertModel and BertTokenizer from the transformers library. The model used is bert-base-uncased, which provides a good balance between performance and computational efficiency. Its uncased nature also makes it suitable for the flickr8k dataset, which contains only lower-case words in the captions.
- I set embedding_size to BERT's hidden size (768) and vocabulary_size to BERT's vocabulary size. This ensures compatibility with the pre-trained BERT model.
- The LSTM's hidden and cell states (h and c) are initialized using linear transformations of the encoder's output, followed by a tanh activation. This remains unchanged for BERT integration.
- Caption generation starts from BERT's cls_token_id. This differs from the standard approach, where a <start> token is used.
- When use_advanced_deep_output is true, the model employs an enhanced deep output layer, which combines the transformed LSTM state, context vector, and current embedding to predict the next word. This mechanism is independent of the embedding type but requires careful dimensionality alignment (standard embeddings are size 512, BERT 768).

A quick introduction to the metrics used to evaluate the performance of the model.
BLEU (Bilingual Evaluation Understudy) Score is a widely used metric for evaluating the quality of text which has been machine-translated from one language to another. In the context of image captioning, it's adapted to assess the quality of generated captions compared to a set of reference captions.
The BLEU score is calculated based on n-gram precision. For each n-gram size (up to a predefined limit, typically 4), it compares the n-grams of the generated text with the n-grams of the reference texts, counting the number of matches. These matches are then adjusted by a brevity penalty to penalize overly short predictions. The formula for BLEU score is:
$$ \text{BLEU} = \text{BP} \cdot \exp\left(\sum_{n=1}^{N} w_n \log p_n\right) $$

Where $\text{BP}$ is the brevity penalty (penalizing hypotheses shorter than the references), $w_n$ are the n-gram weights (typically uniform, $w_n = 1/N$), and $p_n$ is the clipped precision for n-grams of size $n$.
BLEU score provides a quantitative measure of the similarity between the machine-generated text and human-generated reference texts. Higher BLEU scores indicate better alignment with the reference captions, suggesting higher quality translations or captions. Scores between 0.6 and 0.7 are considered the best one can achieve (https://towardsdatascience.com).
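To make the formula concrete, here is a self-contained sketch of single-sentence BLEU, combining clipped n-gram precision with the brevity penalty. This is an illustrative reimplementation; the actual evaluation in this project uses NLTK's corpus_bleu.

```python
import math
from collections import Counter

def ngram_precision(hyp, refs, n):
    """Clipped n-gram precision of a hypothesis against multiple references."""
    hyp_ngrams = Counter(tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1))
    max_ref = Counter()
    for ref in refs:
        ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
        for gram, count in ref_ngrams.items():
            max_ref[gram] = max(max_ref[gram], count)
    # Each hypothesis n-gram counts at most as often as it appears in any reference
    clipped = sum(min(count, max_ref[gram]) for gram, count in hyp_ngrams.items())
    total = sum(hyp_ngrams.values())
    return clipped / total if total else 0.0

def bleu(hyp, refs, max_n=4):
    """Sentence BLEU with uniform weights w_n = 1/N and a brevity penalty."""
    precisions = [ngram_precision(hyp, refs, n) for n in range(1, max_n + 1)]
    if min(precisions) == 0:
        return 0.0  # log(0) is undefined; the score collapses to 0
    closest_ref = min(refs, key=lambda ref: abs(len(ref) - len(hyp)))
    bp = 1.0 if len(hyp) >= len(closest_ref) else math.exp(1 - len(closest_ref) / len(hyp))
    return bp * math.exp(sum(math.log(p) for p in precisions) / max_n)

ref = ['a', 'dog', 'runs', 'in', 'the', 'grass']
print(bleu(ref, [ref]))  # 1.0 for a perfect match
```

Note how a single missing 4-gram drives the geometric mean, and hence BLEU-4, down sharply, which is why BLEU-4 is a stricter fluency measure than BLEU-1.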
Top-N accuracy is a performance metric used to evaluate classification models, including those in image captioning where the task can be viewed as predicting the next word in a sequence.
Top-N accuracy is computed as follows:
$$ \text{Top-N Accuracy} = \frac{\text{Number of times the correct label is in the top N predictions}}{\text{Total number of predictions}} $$

For my deep learning experiment, I focused on training four distinct variants of the "Show, Attend and Tell" model to investigate the impact of different configurations, particularly the presence of attention mechanisms and the use of BERT embeddings. The experiments were conducted using a controlled setup, with consistent hyperparameters and data splits.
- Standard model with attention (plain_att): This variant included attention, teacher forcing, and advanced deep output but did not utilize BERT embeddings.
- Standard model without attention (plain_noatt): Similar to the first variant but without the attention mechanism.
- BERT model with attention (bert_att): This variant employed both BERT embeddings and the attention mechanism, alongside teacher forcing and advanced deep output.
- BERT model without attention (bert_noatt): This setup utilized BERT embeddings but without the attention mechanism.

All models were trained on the full dataset (fraction=1.0), and the default network for the encoder was 'vgg19'. The command-line flags for teacher forcing (--tf), advanced deep output (--ado), BERT embeddings (--bert), and attention mechanism (--attention) were used to toggle these features.

The primary objective of these experiments was to assess how different configurations (attention mechanism and BERT embeddings) influence the model's performance. By training these four variants under consistent conditions, I aimed to draw meaningful comparisons and insights into the individual and combined effects of attention and BERT embeddings on the image captioning task.
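Using these flags, the four runs could be launched along the following lines. This is a sketch: the script name and feature flags come from this repo, while all remaining hyperparameters are assumed to keep their defaults.

```shell
# Hypothetical invocations for the four variants; other arguments
# (epochs, batch size, learning rate, ...) use the script defaults.
python train.py --attention --tf --ado          # plain_att
python train.py --tf --ado                      # plain_noatt
python train.py --attention --tf --ado --bert   # bert_att
python train.py --tf --ado --bert               # bert_noatt
```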
Below, I delve into the key aspects of the training, including the distinctions between train, validation, and test modes, and other critical elements of the training pipeline.
- Train mode: The encoder is set to eval mode (to use the pre-trained model without modification), and the decoder is set to train mode. This setup enables the model to learn from the training dataset. The key operations in this mode include processing images and captions, computing the forward pass, calculating loss (including attention regularization), performing backpropagation, and updating model parameters.
- Validation mode: Both encoder and decoder are set to eval mode, ensuring no updates to the model parameters occur. The validation mode is crucial for assessing the model's performance on unseen data without the influence of training dynamics.
- Test mode: Both encoder and decoder are in eval mode. The test mode is used for the final evaluation of the model's performance, on the test dataset that the model has not seen during training or validation.
- Tokenization: With bert=True, the BERT tokenizer and model are loaded, and the vocabulary size is set accordingly. Otherwise, the custom word dictionary is used.

Below are the most relevant parts of the training script (some code removed for brevity). The full training script is available at train.py.
class EvalMode(Enum):
    VALIDATION = 'val'
    TEST = 'test'

def main(args):
    set_seed(args.seed)
    wandb.init(project='show-attend-and-tell', entity='yvokeller', config=args)
    # ...
    encoder = Encoder(args.network)
    decoder = Decoder(vocabulary_size, encoder.dim, tf=args.tf, ado=args.ado, bert=args.bert, attention=args.attention)

    optimizer = optim.Adam(decoder.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.step_size)
    cross_entropy_loss = nn.CrossEntropyLoss().to(mps_device)

    train_loader = DataLoader(
        ImageCaptionDataset(data_transforms, args.data, fraction=args.fraction, bert=args.bert, split_type='train'),
        batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(
        ImageCaptionDataset(data_transforms, args.data, fraction=args.fraction, bert=args.bert, split_type='val'),
        batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    test_loader = DataLoader(
        ImageCaptionDataset(data_transforms, args.data, fraction=args.fraction, bert=args.bert, split_type='test'),
        batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)

    print(f'Starting training with {args}')
    for epoch in range(1, args.epochs + 1):
        train(epoch, encoder, decoder, optimizer, cross_entropy_loss,
              train_loader, word_dict, args.alpha_c, args.log_interval, bert=args.bert, tokenizer=bert_tokenizer, args=args)
        validate(epoch, encoder, decoder, cross_entropy_loss, val_loader,
                 word_dict, args.alpha_c, args.log_interval, bert=args.bert, tokenizer=bert_tokenizer)
        scheduler.step()

        if args.perform_test:
            test(epoch, encoder, decoder, cross_entropy_loss, test_loader,
                 word_dict, args.alpha_c, args.log_interval, bert=args.bert, tokenizer=bert_tokenizer)

    wandb.finish()
def train(epoch, encoder, decoder, optimizer, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, bert=False, tokenizer=None, args={}):
    print(f"Epoch {epoch} - Starting train")
    encoder.eval()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    for batch_idx, (imgs, captions, _) in enumerate(data_loader):
        imgs, captions = Variable(imgs).to(mps_device), Variable(captions).to(mps_device)
        img_features = encoder(imgs)
        optimizer.zero_grad()

        preds, alphas = decoder(img_features, captions)
        targets = captions[:, 1:]  # skip <start> token for loss calculation

        # Calculate accuracy
        padding_idx = word_dict['<pad>'] if not bert else tokenizer.pad_token_id
        acc1 = sequence_accuracy(preds, targets, 1, ignore_index=padding_idx, tokenizer=tokenizer)
        acc5 = sequence_accuracy(preds, targets, 5, ignore_index=padding_idx, tokenizer=tokenizer)

        # Calculate loss
        packed_targets = pack_padded_sequence(targets, [len(tar) - 1 for tar in targets], batch_first=True)[0]
        packed_preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

        # Encourage total attention (alphas) to be close to 1, thus penalize when the sum is far from 1
        att_regularization = alpha_c * ((1 - alphas.sum(1)) ** 2).mean()

        loss = cross_entropy_loss(packed_preds, packed_targets)
        loss += att_regularization  # pytorch autograd will calculate gradients for both loss and att_regularization
        loss.backward()
        optimizer.step()

        if bert:
            total_caption_length = calculate_caption_lengths(...)
        else:
            total_caption_length = calculate_caption_lengths(...)

        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print(f'Train Batch: [{batch_idx}/{len(data_loader)}]\t'
                  f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                  f'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  f'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})')
            wandb.log({
                'train_loss': losses.avg, 'train_top1_acc': top1.avg, 'train_top5_acc': top5.avg, 'epoch': epoch,
                'train_loss_raw': losses.val, 'train_top1_acc_raw': top1.val, 'train_top5_acc_raw': top5.val
            })
def validate(epoch, *args, **kwargs):
    print(f"Epoch {epoch} - Starting validation")
    return run_evaluation(epoch, *args, mode=EvalMode.VALIDATION, **kwargs)

def test(epoch, *args, **kwargs):
    print(f"Epoch {epoch} - Starting test")
    return run_evaluation(epoch, *args, mode=EvalMode.TEST, **kwargs)

def run_evaluation(epoch, encoder, decoder, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, bert=False, tokenizer=None, mode=EvalMode.VALIDATION):
    encoder.eval()
    decoder.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    decoded_captions = []      # list of the single assigned caption for each image
    decoded_all_captions = []  # list of lists of all captions present in the dataset for each image, thus captions may repeat in different lists
    decoded_hypotheses = []    # list of the single predicted caption for each image

    predictions_table = wandb.Table(columns=["epoch", "mode", "target_caption", "pred_caption"])

    with torch.no_grad():
        logged_attention_visualizations_count = 0
        for batch_idx, (imgs, captions, all_captions) in enumerate(data_loader):
            imgs, captions = Variable(imgs).to(mps_device), Variable(captions).to(mps_device)
            img_features = encoder(imgs)
            preds, alphas = decoder(img_features, captions)
            targets = captions[:, 1:]

            # Calculate accuracy
            padding_idx = word_dict['<pad>'] if not bert else tokenizer.pad_token_id
            acc1 = sequence_accuracy(preds, targets, 1, ignore_index=padding_idx, tokenizer=tokenizer)
            acc5 = sequence_accuracy(preds, targets, 5, ignore_index=padding_idx, tokenizer=tokenizer)

            # Calculate loss
            packed_targets = pack_padded_sequence(targets, [len(tar) - 1 for tar in targets], batch_first=True)[0]
            packed_preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]
            att_regularization = alpha_c * ((1 - alphas.sum(1)) ** 2).mean()
            loss = cross_entropy_loss(packed_preds, packed_targets)
            loss += att_regularization

            if bert:
                total_caption_length = calculate_caption_lengths(...)
            else:
                total_caption_length = calculate_caption_lengths(...)

            losses.update(loss.item(), total_caption_length)
            top1.update(acc1, total_caption_length)
            top5.update(acc5, total_caption_length)

            if bert:
                # ... DECODE CAPTIONS BERT
                pass
            else:
                # ... DECODE CAPTIONS STANDARD
                pass

            if batch_idx % log_interval == 0:
                print(f'{mode} Batch: [{batch_idx}/{len(data_loader)}]\t'
                      f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                      f'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                      f'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})')

            if mode == EvalMode.TEST:
                # Calculate the start index for the current batch
                batch_start_idx = batch_idx * len(imgs)

                # Log the attention visualizations
                for img_idx, img_tensor in enumerate(imgs):
                    # Skip attention visualization if already logged enough
                    if logged_attention_visualizations_count >= 50:
                        break
                    logged_attention_visualizations_count += 1

                    # Calculate the global index for the decoded_hypotheses and decoded_captions lists
                    global_caption_idx = batch_start_idx + img_idx

                    if len(decoded_hypotheses[global_caption_idx]) == 0:
                        print(f'No caption for image {global_caption_idx}, skipping attention visualization')
                        break

                    log_attention_visualization_plot(img_tensor, alphas, decoded_hypotheses, decoded_captions, batch_idx, img_idx, global_caption_idx, encoder)

        bleu_1 = corpus_bleu(decoded_all_captions, decoded_hypotheses, weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(decoded_all_captions, decoded_hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(decoded_all_captions, decoded_hypotheses, weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(decoded_all_captions, decoded_hypotheses)

        wandb.log({
            'epoch': epoch,
            f'{epoch}_{mode.value}_caption_predictions': predictions_table,
            f'{mode.value}_loss': losses.avg, f'{mode.value}_top1_acc': top1.avg, f'{mode.value}_top5_acc': top5.avg,
            f'{mode.value}_loss_raw': losses.val, f'{mode.value}_top1_acc_raw': top1.val, f'{mode.value}_top5_acc_raw': top5.val,
            f'{mode.value}_bleu1': bleu_1, f'{mode.value}_bleu2': bleu_2, f'{mode.value}_bleu3': bleu_3, f'{mode.value}_bleu4': bleu_4,
        })

        print(f'{mode} Epoch: {epoch}\t'
              f'BLEU-1 ({bleu_1})\t'
              f'BLEU-2 ({bleu_2})\t'
              f'BLEU-3 ({bleu_3})\t'
              f'BLEU-4 ({bleu_4})\t')
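The doubly stochastic attention penalty used in the training and evaluation loops above (penalizing attention weights whose per-region sums over time stray from 1) can be isolated as a pure-Python sketch. Here `alphas` is assumed to be a nested list of shape [timesteps][regions] for a single example, standing in for the batched tensor used in the real code.

```python
def attention_regularization(alphas, alpha_c):
    """alpha_c * mean over regions of (1 - sum over timesteps of alpha)^2."""
    num_regions = len(alphas[0])
    penalties = []
    for region in range(num_regions):
        total = sum(step[region] for step in alphas)  # total attention this region received
        penalties.append((1.0 - total) ** 2)
    return alpha_c * sum(penalties) / num_regions

# Attention that sums to 1 per region over time incurs no penalty
uniform = [[0.5, 0.5], [0.5, 0.5]]
print(attention_regularization(uniform, alpha_c=1.0))  # 0.0
```

This encourages the model to spread its attention so that every image region is attended to roughly once over the course of generating the caption.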
In this section, I explore the results for the following models, trained as described in chapter Experiment Setup:
- Standard model with attention (plain-att-173): This variant included attention, teacher forcing, and advanced deep output but did not utilize BERT embeddings.
- Standard model without attention (plain-noatt-175): Similar to the first variant but without the attention mechanism.
- BERT model with attention (bert-att-176): This variant employed both BERT embeddings and the attention mechanism, alongside teacher forcing and advanced deep output.
- BERT model without attention (bert-noatt-177): Similar to the previous variant but without the attention mechanism.

The detailed, interactive plots are available here: https://wandb.ai/yvokeller/show-attend-and-tell/reports/Image-Captioning-with-Attention--Vmlldzo2NDQ4Nzc5
- Models with BERT embeddings (bert-) have a similar loss trajectory to those without (plain-). This suggests that the introduction of BERT embeddings did not drastically change the loss landscape for the model, though they end up with a slightly higher loss (~0.25 points).
- Models with attention (-att-) also don't show a significantly different loss pattern compared to those without (-noatt-). This suggests that the attention mechanism does not drastically change the loss landscape either. This is a bit surprising, as the attention mechanism is a key component of the model and was shown to improve performance in the original paper. I take this as a signal that something could be wrong with my implementation of the attention mechanism.
Top-1 Accuracy:
The plots show the top-1 accuracy for both training and validation phases across the different configurations. The metric represents the percentage of times the model's highest-probability prediction for the next word in the caption was indeed the correct word.
Top-5 Accuracy:
The top-5 accuracy plot reflects how often the correct next word appears within the top five predictions of the model.
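As a reminder of how these accuracy numbers are computed, here is a minimal sketch of top-N accuracy over per-step vocabulary scores. This is a simplified stand-in for the `sequence_accuracy` helper used in the training script; it ignores padding and batching.

```python
def top_n_accuracy(step_scores, targets, n):
    """Fraction of steps where the target index is among the n highest-scoring words."""
    correct = 0
    for scores, target in zip(step_scores, targets):
        # Indices of the n words with the highest scores at this step
        top_n = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]
        correct += target in top_n
    return correct / len(targets)

# Two decoding steps over a 3-word vocabulary
scores = [[0.1, 0.7, 0.2],   # step 1: word 1 is the top prediction
          [0.5, 0.3, 0.2]]   # step 2: word 0 is the top prediction
targets = [1, 2]
print(top_n_accuracy(scores, targets, 1))  # 0.5  (step 1 correct, step 2 not)
print(top_n_accuracy(scores, targets, 3))  # 1.0
```

Top-5 accuracy is naturally more forgiving than top-1, which is why its curves sit noticeably higher.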
The plots show the BLEU-1 and BLEU-4 scores calculated during validation and test phases for the model variants. BLEU-1 is indicative of the unigram match between the predicted and reference captions, which is a measure of adequacy, while BLEU-4 considers longer n-gram matches up to four words, which indicates fluency.
Validation BLEU Scores:
Analysis: Across the training steps, we can observe an increase in BLEU scores, indicating improvement in the model's captioning performance as training progresses. The scores tend to plateau, suggesting that the models have reached their performance capacity on the validation set.
Model Comparisons:
- The BERT models (bert-) show worse performance than the plain models in terms of BLEU-1 and BLEU-4 scores. The increased vocabulary size with BERT embeddings (from ~10,000 to ~30,000) could be a contributing factor, as the complexity of the task increases with the larger vocabulary. Another potential reason could be that words from BERT's larger vocabulary are predicted which are not present in the reference captions, leading to lower BLEU scores.
- The attention models (-att-) demonstrate a modest improvement, indicating that attention potentially contributes to the adequacy and fluency of the generated captions.

Test BLEU Scores:

- The plain-att-173 model achieves the highest BLEU-1 and BLEU-4 scores.

Comparisons with Original Paper:
The original paper reported BLEU-1 to BLEU-4 scores of 67, 44.8, 29.9, and 19.5, respectively, on the Flickr8k dataset.
The best performing of the four model variants (plain-att-173) achieves BLEU-1 to BLEU-4 scores of 65, 40, 23.4, and 13.3, respectively. Compared to the paper, the gap widens as the n-gram order increases. The results still look quite promising, but the qualitative evaluation in the next section will provide a more comprehensive picture of the model's performance.
Based on the observations from the training and evaluation, I trained two additional models to investigate the impact of different hyperparameters.
plain-lr-0.001-180:
A standard embedding model with attention, teacher forcing, and advanced deep output, trained for 8 epochs with an increased learning rate of 0.001.
plain-bs-exp-178:
A standard embedding model with attention, teacher forcing, and advanced deep output, but trained for 35 epochs with a batch size of 128.
This last section focuses on the qualitative evaluation of the model's performance, including the generation of captions and attention visualizations.
I developed a script, generate_caption.py, which is designed to generate and visualize image captions. This script is an essential part of my project as it not only generates captions but also provides a visual representation of the attention mechanism at work.
Model Loading: The script can load models either from a specified path or from Weights & Biases (wandb).
Tokenization: Depending on the configuration, the script uses either BERT tokenizer or a custom word dictionary for tokenizing the captions.
Image Preprocessing: The input image is loaded and preprocessed to match the input format expected by the model, including resizing and normalization.
Caption Generation and Visualization:
Command-Line Interface (CLI): The script is designed to be run from the command line, with arguments for specifying the image path and model details. It can also be used in a Jupyter Notebook, as demonstrated in the next section.
The script is available at generate_caption.py.
I'll use the following table to qualitatively judge the model's performance on the test set.
| Criteria | Rating: 1 Point | Rating: 2 Points | Rating: 3 Points | Rating: 4 Points | Rating: 5 Points |
|---|---|---|---|---|---|
| Adequacy | Caption does not relate to the image content. | Caption vaguely relates to the image. | Caption covers basic elements in the image. | Caption covers most elements in the image. | Caption accurately covers all key elements of the image. |
| Semantic Correctness | Caption makes no logical sense. | Caption has significant logical flaws. | Caption is somewhat logical but has some inaccuracies. | Caption is logical with minor inaccuracies. | Caption is completely logical and accurate. |
| Object Detection | Fails to identify any objects. | Identifies less than 50% of objects correctly. | Identifies about 50% of objects correctly. | Identifies most objects correctly. | Identifies all objects correctly. |
| Color Detection | Fails to identify any colors. | Identifies less than 50% of colors correctly. | Identifies about 50% of colors correctly. | Identifies most colors correctly. | Identifies all colors correctly. |
| Attention Visualization | No attention visualization. | Attention visualization is not meaningful. | Attention visualization is somewhat meaningful. | Attention visualization is mostly meaningful. | Attention visualization is very meaningful. |
from generate_caption import generate_caption_visualization, load_model
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
WANDB_PROJECT = "yvokeller/show-attend-and-tell/"
WANDB_MODEL_FOLDER = "model/"
WANDB_MODEL_NAME_TEMPLATE = "model_vgg19_X.pth"
def load_model_from_checkpoint(run_id, checkpoint):
    wandb_run = WANDB_PROJECT + run_id
    wandb_model = WANDB_MODEL_FOLDER + WANDB_MODEL_NAME_TEMPLATE.replace("X", str(checkpoint))
    return load_model(wandb_run=wandb_run, wandb_model=wandb_model)

def caption_images(img_paths, run_id, checkpoint, figsize=(9, 6), beam_size=3):
    encoder, decoder, bert, model_path, model_config_path = load_model_from_checkpoint(run_id, checkpoint)
    for img_path in img_paths:
        generate_caption_visualization(img_path, encoder, decoder, model_path, model_config_path, beam_size=beam_size, figsize=figsize)
# Image Test Sets
own_images = ['data/mine/train.jpeg', 'data/mine/tashi.jpeg', 'data/mine/lake.jpeg']
test_images_flickr8k = [
    'data/flickr8k/imgs/667626_18933d713e.jpg',
    'data/flickr8k/imgs/280706862_14c30d734a.jpg',
    'data/flickr8k/imgs/3072172967_630e9c69d0.jpg',
    'data/flickr8k/imgs/2654514044_a70a6e2c21.jpg',
    'data/flickr8k/imgs/311146855_0b65fdb169.jpg',
    'data/flickr8k/imgs/2218609886_892dcd6915.jpg',
    'data/flickr8k/imgs/2511019188_ca71775f2d.jpg',
    'data/flickr8k/imgs/2435685480_a79d42e564.jpg',
    'data/flickr8k/imgs/3482062809_3b694322c4.jpg'
]
https://wandb.ai/yvokeller/show-attend-and-tell/runs/8nu0sdou
Remarks
Qualitative Evaluation
| Criteria | Rating (1-5) |
|---|---|
| Adequacy | 4 |
| Semantic Correctness | 2 |
| Object Detection | 3 |
| Color Detection | 4 |
| Attention Visualization | 4 |
| **Total** | **17** |
| **Average rating** | **3.4** |
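The total and average rating above follow directly from the per-criterion scores; a trivial sketch of the aggregation:

```python
# Per-criterion scores for this image (from the rating table above)
ratings = {
    'Adequacy': 4,
    'Semantic Correctness': 2,
    'Object Detection': 3,
    'Color Detection': 4,
    'Attention Visualization': 4,
}
total = sum(ratings.values())
average = total / len(ratings)
print(total, average)  # 17 3.4
```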
caption_images(test_images_flickr8k, '8nu0sdou', 8, beam_size=3)